library(here)
library(cowplot)
source(here("utils/data_processing.R"))
source(here("utils/figures.R"))
# Discover every model with processed diagnosis files. File names have the
# form "diagnoses_<model>.csv.gz" or "diagnoses_<model>_icd.csv.gz", so the
# model id is the second piece after splitting. The "." in ".csv" is escaped
# so it only matches a literal dot, and vapply() keeps the result a character
# vector even when the directory contains no matching files.
all_models <- list.files(here("data/processed_diagnoses"), pattern = "gz$") %>%
  str_split("diagnoses_|_icd|\\.csv") %>%
  vapply(function(x) x[2], character(1)) %>%
  unique()
all_models
[1] "claude-3-haiku-20240307_t1-0" "claude-3-opus-20240229_t1-0" "gemini-1.0-pro-002_t1-0" "gemini-1.5-flash-preview-0514_t1-0" "gemini-1.5-pro-001_t1-0"
[6] "gpt-3.5-turbo-1106" "gpt-4-turbo-preview"
Import data
# Load each model's diagnoses in two forms: the original responses
# (icd = FALSE) and the ICD-converted responses (icd = TRUE).
# read_model() is not defined in this file; presumably it comes from the
# utils/ scripts sourced above. Note that gemini-1.5-flash-preview-0514_t1-0
# is listed in all_models but is not loaded here.
df_gpt3.5 <- read_model("gpt-3.5-turbo-1106", icd = FALSE)
df_gpt4.0 <- read_model("gpt-4-turbo-preview", icd = FALSE)
df_claude3_haiku_t1.0 <- read_model("claude-3-haiku-20240307_t1-0", icd = FALSE)
df_claude3_opus_t1.0 <- read_model("claude-3-opus-20240229_t1-0", icd = FALSE)
df_gemini1.0_pro_t1.0 <- read_model("gemini-1.0-pro-002_t1-0", icd = FALSE)
df_gemini1.5_pro_t1.0 <- read_model("gemini-1.5-pro-001_t1-0", icd = FALSE)
# ICD-converted versions of the same six models.
df_gpt3.5_icd <- read_model("gpt-3.5-turbo-1106", icd = TRUE)
df_gpt4.0_icd <- read_model("gpt-4-turbo-preview", icd = TRUE)
df_claude3_haiku_t1.0_icd <- read_model("claude-3-haiku-20240307_t1-0", icd = TRUE)
df_claude3_opus_t1.0_icd <- read_model("claude-3-opus-20240229_t1-0", icd = TRUE)
df_gemini1.0_pro_t1.0_icd <- read_model("gemini-1.0-pro-002_t1-0", icd = TRUE)
df_gemini1.5_pro_t1.0_icd <- read_model("gemini-1.5-pro-001_t1-0", icd = TRUE)
Rank abundance
Original responses
# One rank-abundance plot per model for the original (non-ICD) responses.
# rank_abundance_plot() is presumably defined in the sourced utils/ scripts.
rank_abundance_plot(df_gpt3.5)+ggtitle("ChatGPT 3.5")

rank_abundance_plot(df_gpt4.0)+ggtitle("ChatGPT 4.0")

rank_abundance_plot(df_claude3_haiku_t1.0)+ggtitle("Claude3 Haiku t1.0")

rank_abundance_plot(df_claude3_opus_t1.0)+ggtitle("Claude3 Opus")

rank_abundance_plot(df_gemini1.0_pro_t1.0)+ggtitle("Gemini 1.0 Pro")

rank_abundance_plot(df_gemini1.5_pro_t1.0)+ggtitle("Gemini 1.5 Pro")

ICD converted responses
# One rank-abundance plot per model for the ICD-converted responses.
rank_abundance_plot(df_gpt3.5_icd)+ggtitle("ChatGPT 3.5 ICD")

rank_abundance_plot(df_gpt4.0_icd)+ggtitle("ChatGPT 4.0 ICD")

rank_abundance_plot(df_claude3_haiku_t1.0_icd)+ggtitle("Claude3 Haiku ICD")

rank_abundance_plot(df_claude3_opus_t1.0_icd)+ggtitle("Claude3 Opus ICD")

rank_abundance_plot(df_gemini1.0_pro_t1.0_icd)+ggtitle("Gemini 1.0 Pro ICD")

rank_abundance_plot(df_gemini1.5_pro_t1.0_icd)+ggtitle("Gemini 1.5 Pro ICD")

Combined model data
# All six models combined in one rank-abundance plot (original responses).
multi_ranked_abundance_plot(df_gpt3.5, df_gpt4.0, df_claude3_haiku_t1.0,
df_claude3_opus_t1.0, df_gemini1.0_pro_t1.0,
df_gemini1.5_pro_t1.0)+
ggtitle("Combined model rank abundance", "Original responses")

# The same combined plot for the ICD-converted responses.
multi_ranked_abundance_plot(df_gpt3.5_icd, df_gpt4.0_icd, df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd, df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd)+
ggtitle("Combined model rank abundance", "ICD converted responses")

Top diagnoses plots
# Turn raw axis values into display labels.
# Everything from the first "___" to the end of the string is dropped
# (presumably the suffix added by tidytext::reorder_within() — TODO confirm).
# The remaining text is then wrapped to at most `wrap_width` characters per line.
custom_labeler <- function(x, wrap_width=33) {
  label_text <- str_replace(x, "___.+$", "")
  str_wrap(label_text, width = wrap_width)
}
# Extra layers for the single-model ICD top-diagnosis plots. They set smaller
# text and an x scale whose labels pass through custom_labeler() with a
# 45-character wrap width. Adding this list replaces the plot's existing x
# scale, which triggers ggplot2's "Scale for x is already present" message.
custom_text_formatting <- list(
theme(axis.text = element_text(size = 7, lineheight = 0.7),
strip.text = element_text(size = 7),
axis.title = element_text(size = 9)),
tidytext::scale_x_reordered(labels = ~custom_labeler(., wrap_width = 45))
)
# Top-`n_diag` diagnosis plot for each model's original responses.
n_diag <- 25
sub <- "Original responses"
top_diagnosis_plot(df_gpt3.5, n_diag = n_diag)+ggtitle("ChatGPT 3.5", sub)

top_diagnosis_plot(df_gpt4.0, n_diag = n_diag)+ggtitle("ChatGPT 4.0", sub)

top_diagnosis_plot(df_claude3_haiku_t1.0, n_diag = n_diag)+ggtitle("Claude3 Haiku t1.0", sub)

top_diagnosis_plot(df_claude3_opus_t1.0, n_diag = n_diag)+ggtitle("Claude3 Opus t1.0", sub)

top_diagnosis_plot(df_gemini1.0_pro_t1.0, n_diag = n_diag)+ggtitle("Gemini 1.0 Pro", sub)

top_diagnosis_plot(df_gemini1.5_pro_t1.0, n_diag = n_diag)+ggtitle("Gemini 1.5 Pro", sub)

# Top-`n_diag` diagnosis plot for each model's ICD-converted responses, with
# custom_text_formatting appended (smaller text, wrapped labels). Each call
# replaces the plot's x scale, hence the "Scale for x" messages.
n_diag <- 25
sub <- "ICD converted responses"
top_diagnosis_plot(df_gpt3.5_icd, n_diag = n_diag) + custom_text_formatting + ggtitle("ChatGPT 3.5 ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

top_diagnosis_plot(df_gpt4.0_icd, n_diag = n_diag)+ custom_text_formatting+ggtitle("ChatGPT 4.0 ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

top_diagnosis_plot(df_claude3_haiku_t1.0_icd, n_diag = n_diag)+ custom_text_formatting+ggtitle("Claude3 Haiku t1.0 ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

top_diagnosis_plot(df_claude3_opus_t1.0_icd, n_diag = n_diag)+ custom_text_formatting+ggtitle("Claude3 Opus t1.0 ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

top_diagnosis_plot(df_gemini1.0_pro_t1.0_icd, n_diag = n_diag)+ custom_text_formatting+ggtitle("Gemini 1.0 Pro ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

top_diagnosis_plot(df_gemini1.5_pro_t1.0_icd, n_diag = n_diag)+ custom_text_formatting+ggtitle("Gemini 1.5 Pro ICD", sub)
Scale for x is already present.
Adding another scale for x, which will replace the existing scale.

# All six models combined: top 25 diagnoses of the original responses, drawn
# with the helper's "points" distribution style.
multi_top_diagnosis_plot(distribution_vis = "points", wrap_width=45, n_diag = 25,
df_gpt3.5, df_gpt4.0, df_claude3_haiku_t1.0,
df_claude3_opus_t1.0, df_gemini1.0_pro_t1.0,
df_gemini1.5_pro_t1.0)

# ICD-converted version showing the top 15 diagnoses. Legend keys for the
# `size` aesthetic are drawn at a fixed size of 2.
plt_diag_icd <- multi_top_diagnosis_plot(distribution_vis = "points", wrap_width = 33, n_diag = 15,
df_gpt3.5_icd, df_gpt4.0_icd, df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd, df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd) +
guides(size = guide_legend(override.aes = list(size = 2)))
plt_diag_icd

# Average `freq` per criteria/diagnosis pair over the plotted data. Within each
# criteria set, diagnoses are sorted from most to least frequent.
plt_diag_icd$data %>%
summarise(freq=mean(freq),.by=c("criteria","diagnosis")) %>%
arrange(criteria, desc(freq))
Cumulative top frequency plots
# Cumulative frequency plot for each model: original responses first, then
# ICD-converted responses. The ggplot shown is the `$plot` element of
# cumulative_frequency_plot()'s result.
# Every title names the model actually plotted. The Opus plots had been
# titled "Claude3 Haiku", the Gemini 1.5 ICD plot "Gemini Pro 1.0 ICD", and
# the GPT-3.5 plots "GPT3".
sub <- "Original responses"
cumulative_frequency_plot(df_gpt3.5)$plot+ggtitle("GPT3.5", sub)

cumulative_frequency_plot(df_gpt4.0)$plot+ggtitle("GPT4", sub)

cumulative_frequency_plot(df_claude3_haiku_t1.0)$plot+ggtitle("Claude3 Haiku", sub)

cumulative_frequency_plot(df_claude3_opus_t1.0)$plot+ggtitle("Claude3 Opus", sub)

cumulative_frequency_plot(df_gemini1.0_pro_t1.0)$plot+ggtitle("Gemini Pro 1.0", sub)

cumulative_frequency_plot(df_gemini1.5_pro_t1.0)$plot+ggtitle("Gemini Pro 1.5", sub)

sub <- "ICD converted responses"
cumulative_frequency_plot(df_gpt3.5_icd)$plot+ggtitle("GPT3.5 ICD", sub)

cumulative_frequency_plot(df_gpt4.0_icd)$plot+ggtitle("GPT4 ICD", sub)

cumulative_frequency_plot(df_claude3_haiku_t1.0_icd)$plot+ggtitle("Claude3 Haiku ICD", sub)

cumulative_frequency_plot(df_claude3_opus_t1.0_icd)$plot+ggtitle("Claude3 Opus ICD", sub)

cumulative_frequency_plot(df_gemini1.0_pro_t1.0_icd)$plot+ggtitle("Gemini Pro 1.0 ICD", sub)

cumulative_frequency_plot(df_gemini1.5_pro_t1.0_icd)$plot+ggtitle("Gemini Pro 1.5 ICD", sub)

# All six models combined: cumulative frequency of the top 25 diagnoses in
# the original responses.
plt_freq <- multi_cumulative_frequency_plot(
n_diagnoses = 25,
distribution_vis = "points",
df_gpt3.5,
df_gpt4.0,
df_claude3_haiku_t1.0,
df_claude3_opus_t1.0,
df_gemini1.0_pro_t1.0,
df_gemini1.5_pro_t1.0
) +
ggtitle("Original responses")
plt_freq

# Average total_frequency per criteria set over the plotted data.
plt_freq$data %>% summarise(freq = mean(total_frequency), .by = "criteria")
# The same analysis for the ICD-converted responses.
plt_freq_icd <- multi_cumulative_frequency_plot(
n_diagnoses = 25,
distribution_vis = "points",
df_gpt3.5_icd,
df_gpt4.0_icd,
df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd,
df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd
) +
ggtitle("ICD converted responses")
plt_freq_icd

# Average total_frequency per criteria set over the plotted data (ICD).
plt_freq_icd$data %>% summarise(freq = mean(total_frequency), .by = "criteria")
Diagnosis rank table
# For each model (original responses, then ICD), show the ranks of diagnoses
# that match the regex "mast |mastoc|anaphylaxis". Diagnosis names are
# truncated to 60 characters for display.
diagnosis_rank_table(df_gpt3.5, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gpt4.0, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_haiku_t1.0, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_opus_t1.0, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.0_pro_t1.0, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.5_pro_t1.0, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gpt3.5_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gpt4.0_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_haiku_t1.0_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_opus_t1.0_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.0_pro_t1.0_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.5_pro_t1.0_icd, "mast |mastoc|anaphylaxis") %>% mutate(diagnosis = substr(diagnosis, 1, 60))
# Summarise, across several models, where the diagnoses matching a pattern
# rank under each MCAS criteria set.
#
# search_pattern: regex passed to diagnosis_rank_table() as `pattern` to pick
#   the diagnoses of interest.
# ...: model data frames. Their argument names (captured by listN()) are used
#   as the model labels.
# Returns one row per diagnosis (column `Diagnosis`) and one column per MCAS
# criteria set. Each cell is "<mean rank>\n[<rank in each model>]".
multi_diagnosis_rank_table <- function(search_pattern, ...){
  listN(...) %>%
    lapply(., diagnosis_rank_table, pattern = search_pattern) %>%
    mapply(function(x, y) {mutate(x, model = y)}, ., names(.), SIMPLIFY = FALSE) %>%
    bind_rows() %>%
    pivot_longer(contains(c("mcas", "kawasaki", "sle", "migraine")),
                 names_to = "criteria", values_to = "rank") %>%
    # Only the MCAS criteria sets are reported in this table.
    filter(grepl("mcas", criteria)) %>%
    format_models() %>%
    format_criteria() %>%
    pivot_wider(names_from = "model", values_from = "rank", names_prefix = "model_") %>%
    rowwise() %>%
    # Mean rank across models, rounded to a whole number; NA ranks are ignored.
    # The digits argument belongs inside round(). Passed as a separate
    # unnamed argument to mutate(), it would add a stray column named "0".
    mutate(mean_rank = round(mean(c_across(contains("model_")), na.rm = TRUE), 0)) %>%
    mutate(ranks = paste(c_across(contains("model_")), collapse = ", ")) %>%
    mutate(output = str_glue("{mean_rank}\n[{ranks}]")) %>%
    # Drop the row-wise grouping before reshaping.
    ungroup() %>%
    select(Diagnosis = diagnosis, criteria, output) %>%
    pivot_wider(names_from = "criteria", values_from = "output")
}
# Ranks of specific ICD-10 codes (T78.2, D47.02, D89.41, D89.49, D89.4) across
# the six ICD-converted models. The trailing space in each alternative stops,
# for example, "D89\\.4 " from also matching D89.41 or D89.49.
rank_table <- multi_diagnosis_rank_table(search_pattern = "T78\\.2 |D47\\.02 |D89\\.41 |D89\\.49 |D89\\.4 ",
df_gpt3.5_icd, df_gpt4.0_icd, df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd, df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd)
rank_table
# Render as a flextable with columns 2-3 (the criteria columns) centred.
rank_table %>%
flextable() %>%
width(width = 30) %>%
align(j = 2:3, align = "center", part = "all")
| Diagnosis | MCAS - Consortium | MCAS - Alternative |
|---|---|---|
| T78.2 Anaphylactic shock, unspecified | 1 [1, 1, 1, 1, 1, 1] | 132 [216, 87, 99, 174, 141, 77] |
| D47.02 Systemic mastocytosis | 9 [22, 8, 7, 2, 14, 2] | 50 [92, 78, 25, 41, 46, 19] |
| D89.41 Monoclonal mast cell activation syndrome | 70 [128, 22, 28, 11, 179, 51] | 74 [141, 64, 22, 37, 168, 12] |
| D89.49 Other mast cell activation disorder | 234 [308, 62, 101, 140, 478, 318] | 625 [1155, 568, 178, 467, 1109, 275] |
| D89.4 Mast cell activation syndrome and related disorders | 496 [726, 174, NA, NA, 605, 478] | 1423 [NA, 833, 1850, 1594, 1906, 933] |
Diversity
# Shannon diversity of the diagnoses for all six models (original responses).
multi_shannon_plot(
distribution_vis = "points",
wrap_width = 45,
n_diag = 25,
df_gpt3.5,
df_gpt4.0,
df_claude3_haiku_t1.0,
df_claude3_opus_t1.0,
df_gemini1.0_pro_t1.0,
df_gemini1.5_pro_t1.0
)

# The same plot for the ICD-converted responses, kept as plt_div_icd for the
# summaries that follow.
plt_div_icd <- multi_shannon_plot(
distribution_vis = "points",
wrap_width = 45,
n_diag = 25,
df_gpt3.5_icd,
df_gpt4.0_icd,
df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd,
df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd
)
plt_div_icd

# Mean Shannon value per criteria set. Then the p-values attached to the plot
# (presumably its ggpubr comparisons, judging by the helper's name).
plt_div_icd$data %>% summarise(shannon=mean(shannon),.by="criteria")
extract_ggpubr_pvalues(plt_div_icd)
Similarity
# Bray-Curtis similarity heatmap per model: original responses, then
# ICD-converted responses.
diagnosis_similarity_heatmap(df_gpt3.5, method = "bray")

diagnosis_similarity_heatmap(df_gpt4.0, method = "bray")

diagnosis_similarity_heatmap(df_claude3_haiku_t1.0, method = "bray")

diagnosis_similarity_heatmap(df_claude3_opus_t1.0, method = "bray")

diagnosis_similarity_heatmap(df_gemini1.0_pro_t1.0, method = "bray")

diagnosis_similarity_heatmap(df_gemini1.5_pro_t1.0, method = "bray")

diagnosis_similarity_heatmap(df_gpt3.5_icd, method = "bray")

diagnosis_similarity_heatmap(df_gpt4.0_icd, method = "bray")

diagnosis_similarity_heatmap(df_claude3_haiku_t1.0_icd, method = "bray")

diagnosis_similarity_heatmap(df_claude3_opus_t1.0_icd, method = "bray")

diagnosis_similarity_heatmap(df_gemini1.0_pro_t1.0_icd, method = "bray")

diagnosis_similarity_heatmap(df_gemini1.5_pro_t1.0_icd, method = "bray")

# All six models in one Bray-Curtis heatmap (original responses), with
# show_dend = F (presumably no dendrogram) and small label/title text.
multi_diagnosis_similarity_heatmap(
method = "bray",
show_dend = F,
label_size = 6,
title_size = 9,
df_gpt3.5,
df_gpt4.0,
df_claude3_haiku_t1.0,
df_claude3_opus_t1.0,
df_gemini1.0_pro_t1.0,
df_gemini1.5_pro_t1.0
)

# The same combined heatmap for the ICD-converted responses.
multi_diagnosis_similarity_heatmap(
method = "bray",
show_dend = F,
label_size = 6,
title_size = 9,
df_gpt3.5_icd,
df_gpt4.0_icd,
df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd,
df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd
)

- Bray-Curtis similarity compares two sets of diagnostic criteria by the
alternative diagnoses each one generates, taking into account how often
each diagnosis appears.
- The heatmaps show that the SLE criteria sets yield very similar diagnoses
at very similar frequencies. In contrast, the diagnoses generated by the two
MCAS criteria sets are as different from each other as they are from those
generated by the criteria for other conditions.
PCA
# PCA plot of diagnoses for each model's original responses. Each title names
# the exact model: both Gemini plots had been titled just "Gemini", and the
# GPT-3.5 plot "GPT3".
diagnosis_pca_plot(df_gpt3.5) + ggtitle("GPT3.5")

diagnosis_pca_plot(df_gpt4.0) + ggtitle("GPT4")

diagnosis_pca_plot(df_claude3_haiku_t1.0) + ggtitle("Claude Haiku")

diagnosis_pca_plot(df_claude3_opus_t1.0) + ggtitle("Claude Opus")

diagnosis_pca_plot(df_gemini1.0_pro_t1.0) + ggtitle("Gemini 1.0 Pro")

diagnosis_pca_plot(df_gemini1.5_pro_t1.0) + ggtitle("Gemini 1.5 Pro")

# Joint PCA across the six ICD-converted models. Each row is one
# model__criteria combination and each column is one diagnosis, holding its
# count (0 when absent). prcomp() centres the data (its default) but does not
# scale it (scale. = F).
# Note: `df` holds a prcomp object here, not a data frame.
df <- listN(df_gpt3.5_icd, df_gpt4.0_icd, df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd, df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd) %>%
mapply(function(x,y) {mutate(x, model=y)}, ., names(.), SIMPLIFY = F) %>%
bind_rows() %>%
count(model, criteria, diagnosis) %>%
pivot_wider(names_from = "diagnosis", values_from = "n", values_fill = 0) %>%
unite(id, model, criteria, sep = "__") %>%
column_to_rownames("id") %>%
prcomp(scale. = F)
# Plot PC1 vs PC2 of the scores, coloured by criteria set. The "__" id is
# split back into model and criteria first.
as.data.frame(df$x) %>%
rownames_to_column("id") %>%
separate(id, into = c("model", "criteria"), sep = "__") %>%
format_criteria() %>%
format_models() %>%
ggplot(aes(x = PC1, y = PC2, color = criteria))+
geom_point()+
# ggrepel::geom_label_repel() +
theme_bw() +
scale_color_brewer(palette = "Dark2")

Precision
- Precision represents how similar each iteration of a 10-point
differential diagnosis is to all other differential diagnoses generated
from the same set of criteria.
- In other words, it measures how reproducible the 10-point differential
diagnosis is for each set of criteria.
- It is measured by computing the Bray-Curtis similarity between all
iterations within a set of criteria.
# Script for calculating all Bray-Curtis similarity values within a criteria
# Found in source(here("scripts/diversity_analysis/calculate_precision.R"))
# Calculate precision
library(here)
source(here("utils/data_processing.R"))
# Model ids are the second piece of "diagnoses_<model>[_icd].csv.gz". The dot
# in ".csv" is escaped so it only matches a literal ".", and vapply() keeps
# the result a character vector even when no files are found.
models <- list.files(here("data/processed_diagnoses"), pattern = "gz$") %>%
  str_split("diagnoses_|_icd|\\.csv") %>%
  vapply(function(x) x[2], character(1)) %>%
  unique()
# TRUE reads the ICD-converted files ("<model>_icd"); FALSE reads the
# original responses.
use_icd <- TRUE
if (use_icd) {
  models <- str_glue("{models}_icd")
}
# For each model: read its diagnoses, compute precision, and write the result
# to data/diversity_analysis/.
for (m in models) {
  print(sprintf("READING IN DATA FOR: %s", m))
  read_path <- sprintf("data/processed_diagnoses/diagnoses_%s.csv.gz", m)
  df <- read_csv(here(read_path))
  print(sprintf("CALCULATING PRECISION FOR: %s", m))
  df <- calculate_precision(df)
  print(sprintf("WRITING PRECISION DATA FOR: %s", m))
  out_path <- sprintf("data/diversity_analysis/diagnosis_precision_%s.csv.gz", m)
  write_csv(df, here(out_path))
}
# Convert precision summary statistics from Bray-Curtis distance to
# Bray-Curtis similarity (similarity = 1 - distance).
#
# df: data frame with numeric distance columns `mean`, `min` and `max`. All
#   other columns (e.g. n, sd, se) are returned unchanged; sd and se are
#   unaffected by the transform x -> 1 - x.
# Returns df with `mean`, `min` and `max` expressed as similarities. Because
# 1 - x reverses order, the new maximum comes from the old minimum and the new
# minimum from the old maximum.
precision_dist_to_sim <- function(df){
  # Compute both flipped bounds from the ORIGINAL columns before overwriting
  # either. A single dplyr::mutate() evaluates its arguments in sequence, so
  # `min = 1 - max` there would read the already-updated `max` and leave `min`
  # equal to its original distance value.
  sim_max <- 1 - df$min
  sim_min <- 1 - df$max
  df$mean <- 1 - df$mean
  df$max <- sim_max
  df$min <- sim_min
  df
}
# Precision plot for the ICD-converted responses. It shows the mean
# Bray-Curtis similarity between iterations, one value per model and criteria
# set, read from the compiled precision summary. Gemini 1.5 Flash is excluded.
# Pairwise Wilcoxon tests are BH-adjusted, and only significant comparisons
# are labelled.
# NOTE(review): plot_selector("points") is added after scale_color_brewer(),
# and ggplot2 reports "Scale for colour is already present" when this plot is
# built, so the Dark2 scale appears to be overridden. Confirm which palette is
# intended.
plt_precision_icd <- read_csv(here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv")) %>%
precision_dist_to_sim() %>%
format_criteria() %>%
format_models() %>%
filter(model != "Gemini 1.5 Flash") %>%
ggplot(aes(x = criteria, y = mean))+
theme_bw()+
theme(axis.text.x = element_text(angle= 45, hjust = 1))+
labs(x="", y = "Average Bray-Curtis Similarity") +
ggpubr::geom_pwc(aes(group = criteria), method = "wilcox.test", p.adjust.method = "BH", hide.ns = T, label = "p.adj.signif", bracket.nudge.y = 0.3, vjust = 0.6, step.increase = 0.14, tip.length = 0.02) +
labs(color=NULL)+
scale_color_brewer(palette = "Dark2") +
plot_selector("points")
Rows: 42 Columns: 8── Column specification ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Delimiter: ","
chr (2): criteria, model
dbl (6): n, mean, max, min, sd, se
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.
plt_precision_icd

# Mean precision (Bray-Curtis similarity) per criteria set, plus the p-values
# extracted from the plot's pairwise Wilcoxon comparisons.
plt_precision_icd$data %>% summarise(mean = mean(mean), .by="criteria")
extract_ggpubr_pvalues(plt_precision_icd)
iNEXT
# Print one iNEXT plot per requested plot type, faceted and coloured by
# assemblage.
#
# inext_obj: an iNEXT result object, e.g. as read from the saved .RDS files.
# types: values passed to iNEXT::ggiNEXT(type = ...). The default 1:3 draws
#   the same three plots as before.
# Returns inext_obj invisibly, so the function can end a pipe without
# printing anything extra.
inext_plots <- function(inext_obj, types = 1:3){
  for (plot_type in types){
    plt <- iNEXT::ggiNEXT(inext_obj, type = plot_type, facet.var = "Assemblage", color.var = "Assemblage") +
      theme_classic() +
      theme(axis.text.x = element_text(angle = 90)) +
      # A single colour scale. An earlier Set1 scale was always replaced by
      # this Dark2 one, so it was dead code.
      scale_color_brewer(palette = "Dark2")
    print(plt)
  }
  invisible(inext_obj)
}
# Draw the iNEXT plots for three saved GPT-4 MCAS iNEXT results. The file-name
# suffixes (dropSingle, psuedoMinus, e250000/e200000) presumably describe the
# preprocessing and extrapolation settings. Confirm this against the script
# that wrote these files.
readRDS(here("data/diversity_analysis/mcas_iNEXT_gpt4_e250000.RDS")) %>% inext_plots()
Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.



readRDS(here("data/diversity_analysis/mcas_iNEXT_dropSingle_gpt4_e200000.RDS")) %>% inext_plots()
Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.



readRDS(here("data/diversity_analysis/mcas_iNEXT_dropSingle_psuedoMinus_gpt4_e200000.RDS")) %>% inext_plots()
Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.



# custom_labeler <- function(x, wrap_width=33) {
# x %>%
# str_replace("___.+$", "") %>%
# str_wrap(width = wrap_width)
# }
Final plot
Version 1
# Shared settings for the final multi-panel figure (Version 1).
# Number of top diagnoses shown in panel A (plt_diags).
n_diagnoses_bar <- 10
# NOTE(review): n_diagnoses_abundance is not referenced by the code in this file.
n_diagnoses_abundance <- 50
# Number of top diagnoses combined in the cumulative-frequency panel.
n_diagnoses_cumulative <- 50
# Font sizes for axis titles and for tick/legend labels.
title_size <- 9
label_size <- 6
# Legend padding, in margin()'s default unit (points).
legend_x_pad <- 4
legend_y_pad <- 2
# Theme tweaks added to every panel: small text and a compact, boxed legend.
# NOTE(review): ggplot2 >= 3.4 deprecates `size` in element_rect() in favour
# of `linewidth`.
apply_text_formatting <- list(theme(
axis.text = element_text(size = label_size),
axis.title = element_text(size = title_size),
legend.text = element_text(size = label_size),
strip.text = element_text(size = label_size+1),
legend.key.height = unit(0.4, 'cm'),
legend.box.background = element_rect(color = "black", size = 1),
legend.margin = margin(t = legend_y_pad, r = legend_x_pad, b = legend_y_pad, l = legend_x_pad*1.1),
legend.spacing.x = unit(0, 'cm'), # Horizontal spacing between legend items
# legend.spacing.y = unit(0, 'cm'),
# legend.box.spacing = unit(0, "cm")
))
# Tight padding (strip_margin points) around facet strip labels.
strip_margin <- 1
strip_formatting <- list(theme(
strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
# strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))
# Panel A: top diagnoses for the six ICD-converted models used in every other
# panel. Gemini 1.5 Pro is used here. No df_gemini1.5_flash_t1.0_icd is ever
# created in this notebook, and Flash is filtered out of the precision panel.
plt_diags <-
  multi_top_diagnosis_plot(
    distribution_vis = "points",
    wrap_width = 58,
    n_diag = n_diagnoses_bar,
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  ) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  theme(axis.text.y = element_text(size = 6.5)) +
  strip_formatting +
  # theme(legend.position = c(-1,0))+
  theme(panel.spacing = unit(0, "lines")) +
  # Increase the point size in the legend.
  guides(color = guide_legend(override.aes = list(size = 2)))
# Panel B: rank-abundance curves for the six ICD-converted models. The legend
# sits at the bottom in two columns.
plt_rank <-
multi_ranked_abundance_plot(
df_gpt3.5_icd,
df_gpt4.0_icd,
df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd,
df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd
) +
theme(legend.position = "bottom", legend.direction = "horizontal") +
apply_text_formatting +
guides(color = guide_legend(ncol = 2))
# Panel C: combined frequency of the top `n_diagnoses_cumulative` diagnoses.
# All six ICD-converted models are included (Gemini 1.5 Pro had been left
# out), so this panel shows the same models as the others. The y label is
# derived from n_diagnoses_cumulative rather than a hard-coded "50".
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(
    y = paste0("Combined frequency\nof top ", n_diagnoses_cumulative, " diagnoses"),
    x = NULL
  )
# Panel D: Shannon diversity for the six ICD-converted models.
plt_shannon <- multi_shannon_plot(
distribution_vis = "points",
wrap_width = 45,
n_diag = 25,
df_gpt3.5_icd,
df_gpt4.0_icd,
df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd,
df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd
) +
apply_text_formatting +
theme(legend.position = "bottom", legend.direction = "horizontal") +
guides(color = guide_legend(ncol = 2))
# Panel E: precision (mean Bray-Curtis similarity between iterations) per
# criteria set, from the compiled ICD precision summary. Gemini 1.5 Flash is
# excluded. Pairwise Wilcoxon tests are BH-adjusted, and only significant
# comparisons are labelled.
# NOTE(review): plot_selector("points") comes after scale_color_brewer(), and
# ggplot2 has reported "Scale for colour is already present" for this
# construction, so the Dark2 scale may be overridden. Confirm.
plt_precision <- read_csv(here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv")) %>%
precision_dist_to_sim() %>%
format_criteria() %>%
format_models() %>%
filter(model != "Gemini 1.5 Flash") %>%
ggplot(aes(x = criteria, y = mean))+
theme_bw()+
theme(axis.text.x = element_text(angle= 45, hjust = 1))+
labs(x="", y = "Average Bray-Curtis\nSimilarity") +
ggpubr::geom_pwc(aes(group = criteria), method = "wilcox.test", p.adjust.method = "BH", hide.ns = T, label = "p.adj.signif", bracket.nudge.y = 0.3, vjust = 0.6, step.increase = 0.14, tip.length = 0.02) +
labs(color=NULL)+
scale_color_brewer(palette = "Dark2") +
plot_selector("points") +
apply_text_formatting +
theme(legend.position = "bottom", legend.direction = "horizontal") +
guides(color = guide_legend(ncol = 2))
# Assemble Version 1: panel A on top, a thin spacer row, then panels B-E side
# by side. The bottom row aligns the panels' top/bottom axes.
full_plt <- plot_grid(
###
plt_diags,
###
NULL,
plot_grid(
plt_rank,
plt_cumulative,
plt_shannon,
plt_precision,
nrow = 1,
axis = 'tb',
align = 'h',
rel_widths = c(1, 0.7, 0.7, 0.7),
labels = c(LETTERS[2:5]),
vjust = 0.2
),
ncol = 1,
rel_heights = c(1.2, 0.05, 0.65),
labels = c("A","","")
)
full_plt

Version 2
# Shared settings for the final multi-panel figure (Version 2; same values as
# Version 1).
# Number of top diagnoses shown in panel A (plt_diags).
n_diagnoses_bar <- 10
# NOTE(review): n_diagnoses_abundance is not referenced by the code in this file.
n_diagnoses_abundance <- 50
# Number of top diagnoses combined in the cumulative-frequency plot.
n_diagnoses_cumulative <- 50
# Font sizes for axis titles and for tick/legend labels.
title_size <- 9
label_size <- 6
# Legend padding, in margin()'s default unit (points).
legend_x_pad <- 4
legend_y_pad <- 2
# Theme tweaks added to every panel: small text and a compact, boxed legend.
# NOTE(review): ggplot2 >= 3.4 deprecates `size` in element_rect() in favour
# of `linewidth`.
apply_text_formatting <- list(theme(
axis.text = element_text(size = label_size),
axis.title = element_text(size = title_size),
legend.text = element_text(size = label_size),
strip.text = element_text(size = label_size+1),
legend.key.height = unit(0.4, 'cm'),
legend.box.background = element_rect(color = "black", size = 1),
legend.margin = margin(t = legend_y_pad, r = legend_x_pad, b = legend_y_pad, l = legend_x_pad*1.1),
legend.spacing.x = unit(0, 'cm'), # Horizontal spacing between legend items
# legend.spacing.y = unit(0, 'cm'),
# legend.box.spacing = unit(0, "cm")
))
# Tight padding (strip_margin points) around facet strip labels.
strip_margin <- 1
strip_formatting <- list(theme(
strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
# strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))
# Panel A: top diagnoses for the six ICD-converted models used in every other
# panel. Gemini 1.5 Pro is used here. No df_gemini1.5_flash_t1.0_icd is ever
# created in this notebook, and Flash is filtered out of the precision panel.
plt_diags <-
  multi_top_diagnosis_plot(
    distribution_vis = "points",
    wrap_width = 58,
    n_diag = n_diagnoses_bar,
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  ) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  theme(axis.text.y = element_text(size = 6.5)) +
  strip_formatting +
  # theme(legend.position = c(-1,0))+
  theme(panel.spacing = unit(0, "lines")) +
  # Larger legend points, with the legend on a single row.
  guides(color = guide_legend(override.aes = list(size = 2), nrow = 1))
# Rank-abundance curves for the six ICD-converted models. The legend is placed
# inside the panel at relative position (0.7, 0.7), in one column with no
# title.
plt_rank <-
multi_ranked_abundance_plot(
df_gpt3.5_icd,
df_gpt4.0_icd,
df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd,
df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd
) +
theme(legend.position = c(0.7,0.7))+
# theme(legend.position = "bottom", legend.direction = "horizontal") +
apply_text_formatting +
guides(color = guide_legend(ncol = 1)) +
labs(color = NULL)
# Combined frequency of the top `n_diagnoses_cumulative` diagnoses (not
# placed in Version 2's layout). All six ICD-converted models are included
# (Gemini 1.5 Pro had been left out). The y label is derived from
# n_diagnoses_cumulative rather than a hard-coded "50".
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(
    y = paste0("Combined frequency\nof top ", n_diagnoses_cumulative, " diagnoses"),
    x = NULL
  )
# Shannon diversity for the six ICD-converted models. Its legend is reused as
# the shared legend in the figure layout.
plt_shannon <- multi_shannon_plot(
distribution_vis = "points",
wrap_width = 45,
n_diag = 25,
df_gpt3.5_icd,
df_gpt4.0_icd,
df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd,
df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd
) +
apply_text_formatting +
theme(legend.position = "bottom", legend.direction = "horizontal") +
guides(color = guide_legend(ncol = 2))
# Precision (mean Bray-Curtis similarity between iterations) per criteria
# set, from the compiled ICD precision summary. Gemini 1.5 Flash is excluded.
# Pairwise Wilcoxon tests are BH-adjusted, and only significant comparisons
# are labelled.
# NOTE(review): plot_selector("points") comes after scale_color_brewer(), and
# ggplot2 has reported "Scale for colour is already present" for this
# construction, so the Dark2 scale may be overridden. Confirm.
plt_precision <- read_csv(here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv")) %>%
precision_dist_to_sim() %>%
format_criteria() %>%
format_models() %>%
filter(model != "Gemini 1.5 Flash") %>%
ggplot(aes(x = criteria, y = mean))+
theme_bw()+
theme(axis.text.x = element_text(angle= 45, hjust = 1))+
labs(x="", y = "Average Bray-Curtis\nSimilarity") +
ggpubr::geom_pwc(aes(group = criteria), method = "wilcox.test", p.adjust.method = "BH", hide.ns = T, label = "p.adj.signif", bracket.nudge.y = 0.3, vjust = 0.6, step.increase = 0.14, tip.length = 0.02) +
labs(color=NULL)+
scale_color_brewer(palette = "Dark2") +
plot_selector("points") +
apply_text_formatting +
theme(legend.position = "bottom", legend.direction = "horizontal") +
guides(color = guide_legend(ncol = 2))
# Assemble Version 2: panel A on top and a thin spacer row. Below them, the
# rank-abundance panel sits beside a sub-grid of the Shannon and precision
# panels, which share one legend taken from the Shannon plot.
full_plt <- plot_grid(
  ###
  plt_diags,
  ###
  NULL,
  plot_grid(
    plt_rank,
    plot_grid(
      plot_grid(
        plt_shannon + theme(legend.position = "none"),
        plt_precision + theme(legend.position = "none"),
        nrow = 1,
        axis = 'tb',
        align = 'h'
      ),
      # Shared legend laid out on a single row. guide_legend() has no `row`
      # argument, so `nrow` is required for the layout to take effect.
      get_legend(plt_shannon + guides(color = guide_legend(nrow = 1))),
      ncol = 1,
      rel_heights = c(1, 0.1)
    ),
    nrow = 1,
    rel_widths = c(1, 1),
    # labels = c(LETTERS[2:5]),
    vjust = 0.2
  ),
  ncol = 1,
  rel_heights = c(1.2, 0.05, 0.65),
  labels = c("A", "", "")
)
full_plt

Version 3
# Shared settings for the final multi-panel figure (Version 3; same values as
# the earlier versions).
# Number of top diagnoses shown in panel A (plt_diags).
n_diagnoses_bar <- 10
# NOTE(review): n_diagnoses_abundance is not referenced by the code in this file.
n_diagnoses_abundance <- 50
# Number of top diagnoses combined in the cumulative-frequency plot.
n_diagnoses_cumulative <- 50
# Font sizes for axis titles and for tick/legend labels.
title_size <- 9
label_size <- 6
# Legend padding, in margin()'s default unit (points).
legend_x_pad <- 4
legend_y_pad <- 2
# Theme tweaks added to every panel: small text and a compact, boxed legend.
# NOTE(review): ggplot2 >= 3.4 deprecates `size` in element_rect() in favour
# of `linewidth`.
apply_text_formatting <- list(theme(
axis.text = element_text(size = label_size),
axis.title = element_text(size = title_size),
legend.text = element_text(size = label_size),
strip.text = element_text(size = label_size+1),
legend.key.height = unit(0.4, 'cm'),
legend.box.background = element_rect(color = "black", size = 1),
legend.margin = margin(t = legend_y_pad, r = legend_x_pad, b = legend_y_pad, l = legend_x_pad*1.1),
legend.spacing.x = unit(0, 'cm'), # Horizontal spacing between legend items
# legend.spacing.y = unit(0, 'cm'),
# legend.box.spacing = unit(0, "cm")
))
# Tight padding (strip_margin points) around facet strip labels.
strip_margin <- 1
strip_formatting <- list(theme(
strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
# strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))
# Panel A: top diagnoses for the six ICD-converted models used in every other
# panel. Gemini 1.5 Pro is used here. No df_gemini1.5_flash_t1.0_icd is ever
# created in this notebook, and Flash is filtered out of the precision panel.
plt_diags <-
  multi_top_diagnosis_plot(
    distribution_vis = "points",
    wrap_width = 58,
    n_diag = n_diagnoses_bar,
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    df_gemini1.5_pro_t1.0_icd
  ) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  theme(axis.text.y = element_text(size = 6.5)) +
  strip_formatting +
  # theme(legend.position = c(-1,0))+
  theme(panel.spacing = unit(0, "lines")) +
  # Larger legend points, with the legend on a single row.
  guides(color = guide_legend(override.aes = list(size = 2), nrow = 1))
# Panel C: rank-abundance curves for the six ICD-converted models. The legend
# is placed inside the panel at relative position (0.7, 0.7), in one column
# with no title.
plt_rank <-
multi_ranked_abundance_plot(
df_gpt3.5_icd,
df_gpt4.0_icd,
df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd,
df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd
) +
theme(legend.position = c(0.7,0.7))+
# theme(legend.position = "bottom", legend.direction = "horizontal") +
apply_text_formatting +
guides(color = guide_legend(ncol = 1)) +
labs(color = NULL)
# Combined frequency of the top `n_diagnoses_cumulative` diagnoses (not
# placed in Version 3's layout). All six ICD-converted models are included
# (Gemini 1.5 Pro had been left out). The y label is derived from
# n_diagnoses_cumulative rather than a hard-coded "50".
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(
    y = paste0("Combined frequency\nof top ", n_diagnoses_cumulative, " diagnoses"),
    x = NULL
  )
# Panel D: Shannon diversity for the six ICD-converted models. Its legend is
# reused as the shared legend for panels D and E.
plt_shannon <- multi_shannon_plot(
distribution_vis = "points",
wrap_width = 45,
n_diag = 25,
df_gpt3.5_icd,
df_gpt4.0_icd,
df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd,
df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd
) +
apply_text_formatting +
theme(legend.position = "bottom", legend.direction = "horizontal") +
guides(color = guide_legend(ncol = 2))
# Panel E: precision (mean Bray-Curtis similarity between iterations) per
# criteria set, from the compiled ICD precision summary. Gemini 1.5 Flash is
# excluded. Pairwise Wilcoxon tests are BH-adjusted, and only significant
# comparisons are labelled.
# NOTE(review): plot_selector("points") comes after scale_color_brewer(), and
# ggplot2 has reported "Scale for colour is already present" for this
# construction, so the Dark2 scale may be overridden. Confirm.
plt_precision <- read_csv(here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv")) %>%
precision_dist_to_sim() %>%
format_criteria() %>%
format_models() %>%
filter(model != "Gemini 1.5 Flash") %>%
ggplot(aes(x = criteria, y = mean))+
theme_bw()+
theme(axis.text.x = element_text(angle= 45, hjust = 1))+
labs(x="", y = "Mean Bray-Curtis Similarity") +
ggpubr::geom_pwc(aes(group = criteria), method = "wilcox.test", p.adjust.method = "BH", hide.ns = T, label = "p.adj.signif", bracket.nudge.y = 0.3, vjust = 0.6, step.increase = 0.14, tip.length = 0.02) +
labs(color=NULL)+
scale_color_brewer(palette = "Dark2") +
plot_selector("points") +
apply_text_formatting +
theme(legend.position = "bottom", legend.direction = "horizontal") +
guides(color = guide_legend(ncol = 2))
# Panel B: combined Bray-Curtis similarity heatmap for the six ICD-converted
# models, with a horizontal legend and no dendrogram (show_dend = F). The
# result is drawn later with ComplexHeatmap::draw().
plt_similarity <- multi_diagnosis_similarity_heatmap(
method = "bray",
show_dend = F,
legend_label = "Bray-Curtis similarity",
legend_direction = "horizontal",
label_size = 6,
title_size = 9,
df_gpt3.5_icd,
df_gpt4.0_icd,
df_claude3_haiku_t1.0_icd,
df_claude3_opus_t1.0_icd,
df_gemini1.0_pro_t1.0_icd,
df_gemini1.5_pro_t1.0_icd
)
# Assemble Version 3: panel A on top and a thin spacer row. Below them are
# three columns: the Bray-Curtis similarity heatmap (B), the rank-abundance
# curves (C), and the Shannon (D) and precision (E) panels, which share one
# legend.
full_plt <- plot_grid(
  ###
  plt_diags,
  ###
  NULL,
  plot_grid(
    # Render the ComplexHeatmap object to a grob so cowplot can place it.
    grid::grid.grabExpr(ComplexHeatmap::draw(plt_similarity, heatmap_legend_side = 'bottom')),
    plt_rank,
    # NULL,
    plot_grid(
      plot_grid(
        plt_shannon + theme(legend.position = "none"),
        plt_precision + theme(legend.position = "none"),
        nrow = 1,
        axis = 'tb',
        align = 'h',
        labels = c(LETTERS[4:5])
      ),
      # Shared legend laid out on a single row. guide_legend() has no `row`
      # argument, so `nrow` is required for the layout to take effect.
      get_legend(plt_shannon + guides(color = guide_legend(nrow = 1))),
      ncol = 1,
      rel_heights = c(1, 0.1)
    ),
    nrow = 1,
    # rel_widths = c(1, 0.01, 0.8, 0.9),
    rel_widths = c(0.8, 1, 0.9),
    labels = c(LETTERS[2:3]),
    vjust = 0.2
  ),
  ncol = 1,
  rel_heights = c(1.2, 0.05, 0.65),
  labels = c("A", "", "")
)
full_plt

Things to fix:
- Legend position for C–E
- Legend width for B
- Move the legend for A to the left of “Frequency?”
- Rank plot line weight
# Save the final figure (full_plt as last assigned, i.e. Version 3) as a
# 7.5 x 7.5 inch PDF (inches are ggsave's default unit).
ggsave(plot=full_plt,filename=here("figures/3_diagnosis_diversity.pdf"), width = 7.5, height = 7.5)
# NOTE(review): set_table_properties() normally takes a flextable as its
# first argument (`x`), and none is supplied here. Confirm this call does what
# is intended; a global default via set_flextable_defaults() may have been
# meant.
set_table_properties(opts_pdf = list(tabcolsep = 0))
set_flextable_defaults(fonts_ignore=TRUE)
# Publication version of the MCAS ICD-code rank table, previewed as a PDF.
# NOTE(review): unlike rank_table above, df_gemini1.5_pro_t1.0_icd is not
# passed here. Confirm the omission is intentional.
multi_diagnosis_rank_table(search_pattern = "T78\\.2 |D47\\.02 |D89\\.41 |D89\\.49 |D89\\.4 ",
df_gpt3.5_icd, df_gpt4.0_icd, df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd, df_gemini1.0_pro_t1.0_icd) %>%
flextable() %>%
width(width = 2) %>%
fontsize(size = 9) %>%
fontsize(size = 10, part = "header") %>%
padding(padding = 0) %>%
align(j = 2:3, align = "center", part = "all") %>%
set_table_properties(opts_pdf = list(arraystretch = 1.25)) %>%
# Print a PDF preview, then return the flextable itself.
{print(., preview = "pdf");.}
---
title: "Diagnosis distribution analysis"
output: 
  html_notebook:
    toc: true
    toc_float: true
---

```{r, message = F}
library(here)
library(cowplot)
source(here("utils/data_processing.R"))
source(here("utils/figures.R"))
```

```{r}
# Discover every model with processed diagnosis files. File names have the
# form "diagnoses_<model>.csv.gz" or "diagnoses_<model>_icd.csv.gz", so the
# model id is the second piece after splitting. The "." in ".csv" is escaped
# so it only matches a literal dot, and vapply() keeps the result a character
# vector even when the directory contains no matching files.
all_models <- list.files(here("data/processed_diagnoses"), pattern = "gz$") %>%
  str_split("diagnoses_|_icd|\\.csv") %>%
  vapply(function(x) x[2], character(1)) %>%
  unique()
all_models
```
# Import data

```{r, message = F}
# Load each model's original (icd = FALSE) diagnoses with the project helper
# read_model() (presumably defined in one of the sourced utils/ files).
# Note: "gemini-1.5-flash-preview-0514_t1-0" appears in all_models but is not
# loaded here.
df_gpt3.5 <- read_model("gpt-3.5-turbo-1106", icd = FALSE)
df_gpt4.0 <- read_model("gpt-4-turbo-preview", icd = FALSE)
df_claude3_haiku_t1.0 <- read_model("claude-3-haiku-20240307_t1-0", icd = FALSE)
df_claude3_opus_t1.0 <- read_model("claude-3-opus-20240229_t1-0", icd = FALSE)
df_gemini1.0_pro_t1.0 <- read_model("gemini-1.0-pro-002_t1-0", icd = FALSE)
df_gemini1.5_pro_t1.0 <- read_model("gemini-1.5-pro-001_t1-0", icd = FALSE)
```

```{r, message = F}
# Load the ICD-converted (icd = TRUE) diagnoses for the same six models.
# These *_icd data frames feed every panel of the final figure below.
df_gpt3.5_icd <- read_model("gpt-3.5-turbo-1106", icd = TRUE)
df_gpt4.0_icd <- read_model("gpt-4-turbo-preview", icd = TRUE)
df_claude3_haiku_t1.0_icd <- read_model("claude-3-haiku-20240307_t1-0", icd = TRUE)
df_claude3_opus_t1.0_icd <- read_model("claude-3-opus-20240229_t1-0", icd = TRUE)
df_gemini1.0_pro_t1.0_icd <- read_model("gemini-1.0-pro-002_t1-0", icd = TRUE)
df_gemini1.5_pro_t1.0_icd <- read_model("gemini-1.5-pro-001_t1-0", icd = TRUE)
```


# Rank abundance

**Original responses**
```{r}
# One rank-abundance plot per model, original responses
rank_abundance_plot(df_gpt3.5) +
  ggtitle("ChatGPT 3.5")
rank_abundance_plot(df_gpt4.0) +
  ggtitle("ChatGPT 4.0")
rank_abundance_plot(df_claude3_haiku_t1.0) +
  ggtitle("Claude3 Haiku t1.0")
rank_abundance_plot(df_claude3_opus_t1.0) +
  ggtitle("Claude3 Opus")
rank_abundance_plot(df_gemini1.0_pro_t1.0) +
  ggtitle("Gemini 1.0 Pro")
rank_abundance_plot(df_gemini1.5_pro_t1.0) +
  ggtitle("Gemini 1.5 Pro")
```

**ICD converted responses**
```{r}
# One rank-abundance plot per model, ICD-converted responses
rank_abundance_plot(df_gpt3.5_icd) +
  ggtitle("ChatGPT 3.5 ICD")
rank_abundance_plot(df_gpt4.0_icd) +
  ggtitle("ChatGPT 4.0 ICD")
rank_abundance_plot(df_claude3_haiku_t1.0_icd) +
  ggtitle("Claude3 Haiku ICD")
rank_abundance_plot(df_claude3_opus_t1.0_icd) +
  ggtitle("Claude3 Opus ICD")
rank_abundance_plot(df_gemini1.0_pro_t1.0_icd) +
  ggtitle("Gemini 1.0 Pro ICD")
rank_abundance_plot(df_gemini1.5_pro_t1.0_icd) +
  ggtitle("Gemini 1.5 Pro ICD")
```

**Combined model data**

```{r}
# All six models overlaid on one rank-abundance plot (original responses)
multi_ranked_abundance_plot(
  df_gpt3.5,
  df_gpt4.0,
  df_claude3_haiku_t1.0,
  df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0,
  df_gemini1.5_pro_t1.0
) +
  ggtitle("Combined model rank abundance", "Original responses")
```

```{r}
# All six models overlaid on one rank-abundance plot (ICD-converted responses)
multi_ranked_abundance_plot(
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  ggtitle("Combined model rank abundance", "ICD converted responses")
```

# Top diagnoses plots

```{r, fig.width = 12, fig.height = 8}
# Axis-label formatter: drop everything from the first "___" onward (the
# suffix that tidytext's reorder_within() appends), then wrap to wrap_width.
custom_labeler <- function(x, wrap_width=33) {
  stripped <- str_replace(x, "___.+$", "")
  str_wrap(stripped, width = wrap_width)
}

# Compact text sizes plus the reordered x scale that uses custom_labeler()
custom_text_formatting <- list(
  theme(
    axis.text = element_text(size = 7, lineheight = 0.7),
    strip.text = element_text(size = 7),
    axis.title = element_text(size = 9)
  ),
  tidytext::scale_x_reordered(
    labels = function(breaks) custom_labeler(breaks, wrap_width = 45)
  )
)
```

```{r, fig.width = 16, fig.height = 8}
# Top-25 diagnoses per model, original responses
sub <- "Original responses"
n_diag <- 25
top_diagnosis_plot(df_gpt3.5, n_diag = n_diag) +
  ggtitle(label = "ChatGPT 3.5", subtitle = sub)
top_diagnosis_plot(df_gpt4.0, n_diag = n_diag) +
  ggtitle(label = "ChatGPT 4.0", subtitle = sub)
top_diagnosis_plot(df_claude3_haiku_t1.0, n_diag = n_diag) +
  ggtitle(label = "Claude3 Haiku t1.0", subtitle = sub)
top_diagnosis_plot(df_claude3_opus_t1.0, n_diag = n_diag) +
  ggtitle(label = "Claude3 Opus t1.0", subtitle = sub)
top_diagnosis_plot(df_gemini1.0_pro_t1.0, n_diag = n_diag) +
  ggtitle(label = "Gemini 1.0 Pro", subtitle = sub)
top_diagnosis_plot(df_gemini1.5_pro_t1.0, n_diag = n_diag) +
  ggtitle(label = "Gemini 1.5 Pro", subtitle = sub)
```

```{r, fig.width = 16, fig.height = 10}
# Top-25 diagnoses per model, ICD-converted responses, with compact text
sub <- "ICD converted responses"
n_diag <- 25
top_diagnosis_plot(df_gpt3.5_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle(label = "ChatGPT 3.5 ICD", subtitle = sub)
top_diagnosis_plot(df_gpt4.0_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle(label = "ChatGPT 4.0 ICD", subtitle = sub)
top_diagnosis_plot(df_claude3_haiku_t1.0_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle(label = "Claude3 Haiku t1.0 ICD", subtitle = sub)
top_diagnosis_plot(df_claude3_opus_t1.0_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle(label = "Claude3 Opus t1.0 ICD", subtitle = sub)
top_diagnosis_plot(df_gemini1.0_pro_t1.0_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle(label = "Gemini 1.0 Pro ICD", subtitle = sub)
top_diagnosis_plot(df_gemini1.5_pro_t1.0_icd, n_diag = n_diag) +
  custom_text_formatting +
  ggtitle(label = "Gemini 1.5 Pro ICD", subtitle = sub)
```

```{r, fig.width = 16, fig.height = 10}
# Top-25 diagnoses for all six models in one faceted plot (original responses)
multi_top_diagnosis_plot(
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25,
  df_gpt3.5,
  df_gpt4.0,
  df_claude3_haiku_t1.0,
  df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0,
  df_gemini1.5_pro_t1.0
)
```



```{r, fig.width = 16, fig.height = 10}
# Top-15 ICD diagnoses for all six models; legend point size enlarged
plt_diag_icd <- multi_top_diagnosis_plot(
  distribution_vis = "points",
  wrap_width = 33,
  n_diag = 15,
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  guides(size = guide_legend(override.aes = list(size = 2)))

plt_diag_icd

# Mean frequency of each plotted diagnosis per criteria, most frequent first
plt_diag_icd$data %>%
  summarise(freq = mean(freq), .by = c("criteria", "diagnosis")) %>%
  arrange(criteria, desc(freq))
```

# Cumulative top frequency plots


```{r, fig.width=3, fig.height=3.5}
# Cumulative top-diagnosis frequency plot for each model (original responses).
# Each title must name the model whose data frame is plotted.
sub <- "Original responses"
cumulative_frequency_plot(df_gpt3.5)$plot + ggtitle("GPT3", sub)
cumulative_frequency_plot(df_gpt4.0)$plot + ggtitle("GPT4", sub)
cumulative_frequency_plot(df_claude3_haiku_t1.0)$plot + ggtitle("Claude3 Haiku", sub)
cumulative_frequency_plot(df_claude3_opus_t1.0)$plot + ggtitle("Claude3 Opus", sub)
cumulative_frequency_plot(df_gemini1.0_pro_t1.0)$plot + ggtitle("Gemini Pro 1.0", sub)
cumulative_frequency_plot(df_gemini1.5_pro_t1.0)$plot + ggtitle("Gemini Pro 1.5", sub)
```

```{r, fig.width=3, fig.height=3.5}
# Cumulative top-diagnosis frequency plot for each model (ICD-converted
# responses). Each title must name the model whose data frame is plotted.
sub <- "ICD converted responses"
cumulative_frequency_plot(df_gpt3.5_icd)$plot + ggtitle("GPT3 ICD", sub)
cumulative_frequency_plot(df_gpt4.0_icd)$plot + ggtitle("GPT4 ICD", sub)
cumulative_frequency_plot(df_claude3_haiku_t1.0_icd)$plot + ggtitle("Claude3 Haiku ICD", sub)
cumulative_frequency_plot(df_claude3_opus_t1.0_icd)$plot + ggtitle("Claude3 Opus ICD", sub)
cumulative_frequency_plot(df_gemini1.0_pro_t1.0_icd)$plot + ggtitle("Gemini Pro 1.0 ICD", sub)
cumulative_frequency_plot(df_gemini1.5_pro_t1.0_icd)$plot + ggtitle("Gemini Pro 1.5 ICD", sub)
```


```{r, fig.width=4, fig.height=3.5}
# Combined frequency of the top 25 diagnoses, all models (original responses)
plt_freq <- multi_cumulative_frequency_plot(
  n_diagnoses = 25,
  distribution_vis = "points",
  df_gpt3.5, df_gpt4.0,
  df_claude3_haiku_t1.0, df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0, df_gemini1.5_pro_t1.0
) +
  ggtitle("Original responses")

plt_freq
# Mean combined frequency per criteria across models
plt_freq$data %>%
  summarise(freq = mean(total_frequency), .by = "criteria")
```

```{r, fig.width=4, fig.height=3.5}
# Combined frequency of the top 25 diagnoses, all models (ICD-converted)
plt_freq_icd <- multi_cumulative_frequency_plot(
  n_diagnoses = 25,
  distribution_vis = "points",
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
) +
  ggtitle("ICD converted responses")

plt_freq_icd
# Mean combined frequency per criteria across models
plt_freq_icd$data %>%
  summarise(freq = mean(total_frequency), .by = "criteria")
```

# Diagnosis rank table

```{r}
# Diagnoses matching "mast |mastoc|anaphylaxis" for each model (original
# responses); diagnosis text truncated to 60 characters for display
diagnosis_rank_table(df_gpt3.5, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gpt4.0, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_haiku_t1.0, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_opus_t1.0, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.0_pro_t1.0, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.5_pro_t1.0, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
```
```{r}
# Same lookup on the ICD-converted responses
diagnosis_rank_table(df_gpt3.5_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gpt4.0_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_haiku_t1.0_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_claude3_opus_t1.0_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.0_pro_t1.0_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
diagnosis_rank_table(df_gemini1.5_pro_t1.0_icd, "mast |mastoc|anaphylaxis") %>%
  mutate(diagnosis = substr(diagnosis, 1, 60))
```


```{r}
# Ranks of ICD codes T78.2, D47.02, D89.41, D89.49 and D89.4 across all models
rank_table <- multi_diagnosis_rank_table(
  search_pattern = "T78\\.2 |D47\\.02 |D89\\.41 |D89\\.49 |D89\\.4 ",
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
)
rank_table
```

```{r}
# Render rank_table as a flextable with columns 2-3 centred
rank_table %>%
  flextable() %>%
  width(width = 30) %>%
  align(j = 2:3, align = "center", part = "all")
```

# Diversity

```{r, fig.width=4, fig.height=3.5}
# Shannon diversity per criteria, all models (original responses)
multi_shannon_plot(
  df_gpt3.5, df_gpt4.0,
  df_claude3_haiku_t1.0, df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0, df_gemini1.5_pro_t1.0,
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25
)
```


```{r, fig.width=4, fig.height=3.5}
# Shannon diversity per criteria, all models (ICD-converted responses)
plt_div_icd <- multi_shannon_plot(
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25
)

plt_div_icd
# Mean Shannon index per criteria, then the pairwise-test p-values
plt_div_icd$data %>%
  summarise(shannon = mean(shannon), .by = "criteria")
extract_ggpubr_pvalues(plt_div_icd)
```

# Similarity

```{r, fig.width=4.25, fig.height=3.5}
# Bray-Curtis similarity heatmap for each model (original responses)
diagnosis_similarity_heatmap(
  df_gpt3.5,
  method = "bray"
)
diagnosis_similarity_heatmap(
  df_gpt4.0,
  method = "bray"
)
diagnosis_similarity_heatmap(
  df_claude3_haiku_t1.0,
  method = "bray"
)
diagnosis_similarity_heatmap(
  df_claude3_opus_t1.0,
  method = "bray"
)
diagnosis_similarity_heatmap(
  df_gemini1.0_pro_t1.0,
  method = "bray"
)
diagnosis_similarity_heatmap(
  df_gemini1.5_pro_t1.0,
  method = "bray"
)
```
```{r, fig.width=4.25, fig.height=3.5}
# Bray-Curtis similarity heatmap for each model (ICD-converted responses)
diagnosis_similarity_heatmap(
  df_gpt3.5_icd,
  method = "bray"
)
diagnosis_similarity_heatmap(
  df_gpt4.0_icd,
  method = "bray"
)
diagnosis_similarity_heatmap(
  df_claude3_haiku_t1.0_icd,
  method = "bray"
)
diagnosis_similarity_heatmap(
  df_claude3_opus_t1.0_icd,
  method = "bray"
)
diagnosis_similarity_heatmap(
  df_gemini1.0_pro_t1.0_icd,
  method = "bray"
)
diagnosis_similarity_heatmap(
  df_gemini1.5_pro_t1.0_icd,
  method = "bray"
)
```

```{r, fig.width=4.25, fig.height=3.5}
# Combined Bray-Curtis similarity heatmap over all models (original responses),
# dendrogram hidden
multi_diagnosis_similarity_heatmap(
  method = "bray",
  show_dend = FALSE,
  label_size = 6,
  title_size = 9,
  df_gpt3.5, df_gpt4.0,
  df_claude3_haiku_t1.0, df_claude3_opus_t1.0,
  df_gemini1.0_pro_t1.0, df_gemini1.5_pro_t1.0
)
```

```{r, fig.width=4.25, fig.height=3.5}
# Combined Bray-Curtis similarity heatmap over all models (ICD-converted),
# dendrogram hidden
multi_diagnosis_similarity_heatmap(
  method = "bray",
  show_dend = FALSE,
  label_size = 6,
  title_size = 9,
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
)
```
- Bray-Curtis similarity measures how similar two diagnostic criteria are in terms of the alternative diagnoses they generate and the frequencies of those diagnoses.
- This demonstrates that the SLE criteria result in very similar sets and frequencies of diagnoses, whereas the diagnoses associated with the two MCAS criteria are as different from each other as they are from those generated by the criteria for other conditions.

### PCA

```{r, fig.width=4.25, fig.height=3.5}
# Per-model PCA of diagnoses (original responses); each title identifies the
# model so the two Gemini plots can be told apart.
diagnosis_pca_plot(df_gpt3.5) + ggtitle("GPT3")
diagnosis_pca_plot(df_gpt4.0) + ggtitle("GPT4")
diagnosis_pca_plot(df_claude3_haiku_t1.0) + ggtitle("Claude Haiku")
diagnosis_pca_plot(df_claude3_opus_t1.0) + ggtitle("Claude Opus")
diagnosis_pca_plot(df_gemini1.0_pro_t1.0) + ggtitle("Gemini 1.0 Pro")
diagnosis_pca_plot(df_gemini1.5_pro_t1.0) + ggtitle("Gemini 1.5 Pro")
```


```{r}
# PCA across all ICD-converted models: rows are model x criteria combinations,
# columns are diagnosis counts (0 where a diagnosis never occurs).
# listN() is a project helper; presumably it names list elements after the
# argument symbols, which become the `model` column.
df <- listN(
  df_gpt3.5_icd, df_gpt4.0_icd, df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd, df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
) %>%
  mapply(function(x, y) mutate(x, model = y), ., names(.), SIMPLIFY = FALSE) %>%
  bind_rows() %>%
  count(model, criteria, diagnosis) %>%
  pivot_wider(names_from = "diagnosis", values_from = "n", values_fill = 0) %>%
  unite(id, model, criteria, sep = "__") %>%
  column_to_rownames("id") %>%
  prcomp(scale. = FALSE)

# Scores on the first two components, coloured by criteria
as.data.frame(df$x) %>%
  rownames_to_column("id") %>%
  separate(id, into = c("model", "criteria"), sep = "__") %>%
  format_criteria() %>%
  format_models() %>%
  ggplot(aes(x = PC1, y = PC2, color = criteria)) +
  geom_point() +
  # ggrepel::geom_label_repel() +
  theme_bw() +
  scale_color_brewer(palette = "Dark2")
```

# Precision

- Precision represents how similar each iteration of a 10-point differential diagnosis is to all other differential diagnoses generated from the same set of criteria. 
- I.e., how reproducible the 10-point differential diagnosis is for each set of criteria.
- Measured by computing the Bray-Curtis similarity between all iterations within a set of criteria.

```{r, eval=F}
# Script for calculating all Bray-Curtis similarity values within a criteria
# Found in source(here("scripts/diversity_analysis/calculate_precision.R"))
# Calculate precision
library(here)
source(here("utils/data_processing.R"))

# Model ids are the part of "diagnoses_<model>[_icd].csv.gz" between the
# "diagnoses_" prefix and the "_icd" / ".csv" suffix.
models <- list.files(here("data/processed_diagnoses"), pattern = "gz$") %>%
  str_split("diagnoses_|_icd|\\.csv") %>%
  vapply(function(parts) parts[2], character(1)) %>%
  unique()

use_icd <- TRUE

if (use_icd) {
  models <- str_glue("{models}_icd")
}

# For every model: read its diagnoses, compute precision, write the result
for (m in models) {
  print(sprintf("READING IN DATA FOR: %s", m))
  read_path <- sprintf("data/processed_diagnoses/diagnoses_%s.csv.gz", m)
  df <- read_csv(here(read_path))

  print(sprintf("CALCULATING PRECISION FOR: %s", m))
  df <- calculate_precision(df)

  print(sprintf("WRITING PRECISION DATA FOR: %s", m))
  out_path <- sprintf("data/diversity_analysis/diagnosis_precision_%s.csv.gz", m)
  write_csv(df, here(out_path))
}
```

```{r, fig.width=4, fig.height=3.5}
# Convert precision summaries from Bray-Curtis distance to similarity.
#
# A distance d corresponds to a similarity 1 - d. The mean therefore maps to
# 1 - mean. Because 1 - d reverses order, the old maximum distance becomes the
# minimum similarity, and vice versa.
#
# df: data frame with numeric columns `mean`, `min` and `max` holding distance
#   summaries; all other columns (and column order) are left untouched.
# Returns df with those three columns replaced by their similarity values.
precision_dist_to_sim <- function(df){
  # Capture both bounds before overwriting either. Sequential mutate()
  # arguments (max = 1 - min, min = 1 - max) would compute the new `min` from
  # the already-converted `max` and leave `min` unchanged.
  old_min <- df[["min"]]
  old_max <- df[["max"]]
  df[["mean"]] <- 1 - df[["mean"]]
  df[["min"]] <- 1 - old_max
  df[["max"]] <- 1 - old_min
  df
}

# Within-criteria precision (mean Bray-Curtis similarity) per criteria with
# BH-adjusted pairwise Wilcoxon tests; Gemini 1.5 Flash rows are excluded.
plt_precision_icd <- here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv") %>%
  read_csv() %>%
  precision_dist_to_sim() %>%
  format_criteria() %>%
  format_models() %>%
  filter(model != "Gemini 1.5 Flash") %>%
  ggplot(aes(x = criteria, y = mean)) +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  labs(x = "", y = "Average Bray-Curtis Similarity") +
  ggpubr::geom_pwc(
    aes(group = criteria),
    method = "wilcox.test",
    p.adjust.method = "BH",
    hide.ns = TRUE,
    label = "p.adj.signif",
    bracket.nudge.y = 0.3,
    vjust = 0.6,
    step.increase = 0.14,
    tip.length = 0.02
  ) +
  labs(color = NULL) +
  scale_color_brewer(palette = "Dark2") +
  plot_selector("points")

plt_precision_icd
plt_precision_icd$data %>%
  summarise(mean = mean(mean), .by = "criteria")
extract_ggpubr_pvalues(plt_precision_icd)
```


# iNEXT

```{r, fig.width=12, fig.height=4}
# Print the three iNEXT curve types for one iNEXT result object.
#
# inext_obj: an iNEXT result (here read from saved RDS files).
# For type = 1, 2 and 3, builds an iNEXT::ggiNEXT() plot faceted and coloured
# by assemblage and prints it; printing is needed for plots to render from
# inside the loop. Returns NULL invisibly.
inext_plots <- function(inext_obj){
  for (curve_type in seq_len(3)) {
    plt <- iNEXT::ggiNEXT(
      inext_obj,
      type = curve_type,
      facet.var = "Assemblage",
      color.var = "Assemblage"
    ) +
      theme_classic() +
      theme(axis.text.x = element_text(angle = 90)) +
      # A single colour scale: a second scale_color_brewer() would replace an
      # earlier one (with a "Scale for colour is already present" message).
      scale_color_brewer(palette = "Dark2")
    print(plt)
  }
}

# iNEXT results for the MCAS criteria (GPT-4), three saved variants
inext_plots(readRDS(here("data/diversity_analysis/mcas_iNEXT_gpt4_e250000.RDS")))
inext_plots(readRDS(here("data/diversity_analysis/mcas_iNEXT_dropSingle_gpt4_e200000.RDS")))
inext_plots(readRDS(here("data/diversity_analysis/mcas_iNEXT_dropSingle_psuedoMinus_gpt4_e200000.RDS")))
```



```{r, fig.width=7.4, fig.height=6.5}
# custom_labeler <- function(x, wrap_width=33) {
#     x %>%
#         str_replace("___.+$", "") %>%
#         str_wrap(width = wrap_width)
# }
```

# Final plot

### Version 1

```{r, fig.width=7.5, fig.height=8.5, message = F}
# Panel sizes (n_diagnoses_abundance is not referenced in this chunk)
n_diagnoses_bar <- 10
n_diagnoses_abundance <- 50
n_diagnoses_cumulative <- 50

# Shared text sizes and legend padding
title_size <- 9
label_size <- 6
legend_x_pad <- 4
legend_y_pad <- 2

# Theme tweaks added to every panel so text sizes and legend boxes match
apply_text_formatting <- list(theme(
  axis.text = element_text(size = label_size),
  axis.title = element_text(size = title_size),
  legend.text = element_text(size = label_size),
  strip.text = element_text(size = label_size + 1),
  legend.key.height = unit(0.4, 'cm'),
  # element_rect() border width is `linewidth` since ggplot2 3.4.0 (`size` is
  # deprecated there)
  legend.box.background = element_rect(color = "black", linewidth = 1),
  legend.margin = margin(t = legend_y_pad, r = legend_x_pad, b = legend_y_pad, l = legend_x_pad * 1.1),
  legend.spacing.x = unit(0, 'cm')  # Horizontal spacing between legend items
  # legend.spacing.y = unit(0, 'cm'),
  # legend.box.spacing = unit(0, "cm")
))

# Tight 1-unit margins around facet strip labels
strip_margin <- 1
strip_formatting <- list(theme(
  strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
  strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
  # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))

# Top diagnoses per criteria for all six models (ICD-converted responses)
plt_diags <-
  multi_top_diagnosis_plot(
    distribution_vis = "points",
    wrap_width = 58,
    n_diag = n_diagnoses_bar,
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    # Gemini 1.5 Pro, as in every other panel; no Gemini 1.5 Flash data frame
    # is created in this notebook, so referencing one errors.
    df_gemini1.5_pro_t1.0_icd
  ) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  theme(axis.text.y = element_text(size = 6.5)) +
  strip_formatting +
  # theme(legend.position = c(-1,0))+
  theme(panel.spacing = unit(0, "lines")) +
  guides(color = guide_legend(override.aes = list(size = 2)))  # Increase the point size in the legend
  

# Rank-abundance curves for all six models, legend below in two columns
plt_rank <- multi_ranked_abundance_plot(
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2))

# Combined frequency of the top n_diagnoses_cumulative diagnoses; includes all
# six models like the other panels.
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(y = "Combined frequency\nof top 50 diagnoses", x = NULL)

# Shannon diversity per criteria for all six models
plt_shannon <- multi_shannon_plot(
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25
) +
  apply_text_formatting +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(ncol = 2))

# Within-criteria precision (mean Bray-Curtis similarity) with BH-adjusted
# pairwise Wilcoxon tests; Gemini 1.5 Flash rows excluded
plt_precision <- here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv") %>%
  read_csv() %>%
  precision_dist_to_sim() %>%
  format_criteria() %>%
  format_models() %>%
  filter(model != "Gemini 1.5 Flash") %>%
  ggplot(aes(x = criteria, y = mean)) +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  labs(x = "", y = "Average Bray-Curtis\nSimilarity") +
  ggpubr::geom_pwc(
    aes(group = criteria),
    method = "wilcox.test",
    p.adjust.method = "BH",
    hide.ns = TRUE,
    label = "p.adj.signif",
    bracket.nudge.y = 0.3,
    vjust = 0.6,
    step.increase = 0.14,
    tip.length = 0.02
  ) +
  labs(color = NULL) +
  scale_color_brewer(palette = "Dark2") +
  plot_selector("points") +
  apply_text_formatting +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(ncol = 2))

# Figure layout: panel A on top, an empty spacer row, then panels B-E in one row
full_plt <- plot_grid(
  plt_diags,
  NULL,  # spacer row
  plot_grid(
    plt_rank, plt_cumulative, plt_shannon, plt_precision,
    nrow = 1,
    axis = 'tb',
    align = 'h',
    rel_widths = c(1, 0.7, 0.7, 0.7),
    labels = LETTERS[2:5],
    vjust = 0.2
  ),
  ncol = 1,
  rel_heights = c(1.2, 0.05, 0.65),
  labels = c("A", "", "")
)

full_plt
```

### Version 2
```{r, fig.width=7.5, fig.height=8.5, message = F}
# Panel sizes (n_diagnoses_abundance is not referenced in this chunk)
n_diagnoses_bar <- 10
n_diagnoses_abundance <- 50
n_diagnoses_cumulative <- 50

# Shared text sizes and legend padding
title_size <- 9
label_size <- 6
legend_x_pad <- 4
legend_y_pad <- 2

# Theme tweaks added to every panel so text sizes and legend boxes match
apply_text_formatting <- list(theme(
  axis.text = element_text(size = label_size),
  axis.title = element_text(size = title_size),
  legend.text = element_text(size = label_size),
  strip.text = element_text(size = label_size + 1),
  legend.key.height = unit(0.4, 'cm'),
  # element_rect() border width is `linewidth` since ggplot2 3.4.0 (`size` is
  # deprecated there)
  legend.box.background = element_rect(color = "black", linewidth = 1),
  legend.margin = margin(t = legend_y_pad, r = legend_x_pad, b = legend_y_pad, l = legend_x_pad * 1.1),
  legend.spacing.x = unit(0, 'cm')  # Horizontal spacing between legend items
  # legend.spacing.y = unit(0, 'cm'),
  # legend.box.spacing = unit(0, "cm")
))

# Tight 1-unit margins around facet strip labels
strip_margin <- 1
strip_formatting <- list(theme(
  strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
  strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
  # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))

# Top diagnoses per criteria for all six models, single-row legend
plt_diags <-
  multi_top_diagnosis_plot(
    distribution_vis = "points",
    wrap_width = 58,
    n_diag = n_diagnoses_bar,
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    # Gemini 1.5 Pro, as in every other panel; no Gemini 1.5 Flash data frame
    # is created in this notebook, so referencing one errors.
    df_gemini1.5_pro_t1.0_icd
  ) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  theme(axis.text.y = element_text(size = 6.5)) +
  strip_formatting +
  # theme(legend.position = c(-1,0))+
  theme(panel.spacing = unit(0, "lines")) +
  guides(color = guide_legend(override.aes = list(size = 2), nrow = 1))  # Increase the point size in the legend
  

# Rank-abundance curves with a one-column legend placed inside the panel
plt_rank <- multi_ranked_abundance_plot(
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = c(0.7, 0.7)) +
  # theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 1)) +
  labs(color = NULL)

# Combined frequency of the top n_diagnoses_cumulative diagnoses; includes all
# six models like the other panels.
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(y = "Combined frequency\nof top 50 diagnoses", x = NULL)

# Shannon diversity per criteria for all six models
plt_shannon <- multi_shannon_plot(
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25
) +
  apply_text_formatting +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(ncol = 2))

# Within-criteria precision (mean Bray-Curtis similarity) with BH-adjusted
# pairwise Wilcoxon tests; Gemini 1.5 Flash rows excluded
plt_precision <- here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv") %>%
  read_csv() %>%
  precision_dist_to_sim() %>%
  format_criteria() %>%
  format_models() %>%
  filter(model != "Gemini 1.5 Flash") %>%
  ggplot(aes(x = criteria, y = mean)) +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  labs(x = "", y = "Average Bray-Curtis\nSimilarity") +
  ggpubr::geom_pwc(
    aes(group = criteria),
    method = "wilcox.test",
    p.adjust.method = "BH",
    hide.ns = TRUE,
    label = "p.adj.signif",
    bracket.nudge.y = 0.3,
    vjust = 0.6,
    step.increase = 0.14,
    tip.length = 0.02
  ) +
  labs(color = NULL) +
  scale_color_brewer(palette = "Dark2") +
  plot_selector("points") +
  apply_text_formatting +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(ncol = 2))

# Figure layout: panel A on top, spacer row, then the rank plot beside the
# Shannon/precision pair with their shared legend underneath.
full_plt <- plot_grid(
  ###
  plt_diags,
  ###
  NULL,
  plot_grid(
    plt_rank,
    plot_grid(
      plot_grid(
        plt_shannon + theme(legend.position = "none"),
        plt_precision + theme(legend.position = "none"),
        nrow = 1,
        axis = 'tb',
        align = 'h'
      ),
      # Shared legend laid out in a single row; guide_legend() takes `nrow`
      # (it has no `row` argument).
      get_legend(plt_shannon + guides(color = guide_legend(nrow = 1))),
      ncol = 1,
      rel_heights = c(1, 0.1)
    ),
    nrow = 1,
    rel_widths = c(1, 1),
    # labels = c(LETTERS[2:5]),
    vjust = 0.2
  ),
  ncol = 1,
  rel_heights = c(1.2, 0.05, 0.65),
  labels = c("A", "", "")
)

full_plt
```

### Version 3
```{r, fig.width=7.5, fig.height=8.5, message = F}
# Panel sizes (n_diagnoses_abundance is not referenced in this chunk)
n_diagnoses_bar <- 10
n_diagnoses_abundance <- 50
n_diagnoses_cumulative <- 50

# Shared text sizes and legend padding
title_size <- 9
label_size <- 6
legend_x_pad <- 4
legend_y_pad <- 2

# Theme tweaks added to every panel so text sizes and legend boxes match
apply_text_formatting <- list(theme(
  axis.text = element_text(size = label_size),
  axis.title = element_text(size = title_size),
  legend.text = element_text(size = label_size),
  strip.text = element_text(size = label_size + 1),
  legend.key.height = unit(0.4, 'cm'),
  # element_rect() border width is `linewidth` since ggplot2 3.4.0 (`size` is
  # deprecated there)
  legend.box.background = element_rect(color = "black", linewidth = 1),
  legend.margin = margin(t = legend_y_pad, r = legend_x_pad, b = legend_y_pad, l = legend_x_pad * 1.1),
  legend.spacing.x = unit(0, 'cm')  # Horizontal spacing between legend items
  # legend.spacing.y = unit(0, 'cm'),
  # legend.box.spacing = unit(0, "cm")
))

# Tight 1-unit margins around facet strip labels
strip_margin <- 1
strip_formatting <- list(theme(
  strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
  strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
  # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))

# Top diagnoses per criteria for all six models, single-row legend
plt_diags <-
  multi_top_diagnosis_plot(
    distribution_vis = "points",
    wrap_width = 58,
    n_diag = n_diagnoses_bar,
    df_gpt3.5_icd,
    df_gpt4.0_icd,
    df_claude3_haiku_t1.0_icd,
    df_claude3_opus_t1.0_icd,
    df_gemini1.0_pro_t1.0_icd,
    # Gemini 1.5 Pro, as in every other panel; no Gemini 1.5 Flash data frame
    # is created in this notebook, so referencing one errors.
    df_gemini1.5_pro_t1.0_icd
  ) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  theme(axis.text.y = element_text(size = 6.5)) +
  strip_formatting +
  # theme(legend.position = c(-1,0))+
  theme(panel.spacing = unit(0, "lines")) +
  guides(color = guide_legend(override.aes = list(size = 2), nrow = 1))  # Increase the point size in the legend
  

# Rank-abundance curves with a one-column legend placed inside the panel
plt_rank <- multi_ranked_abundance_plot(
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = c(0.7, 0.7)) +
  # theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 1)) +
  labs(color = NULL)

# Combined frequency of the top n_diagnoses_cumulative diagnoses; includes all
# six models like the other panels.
plt_cumulative <- multi_cumulative_frequency_plot(
  n_diagnoses = n_diagnoses_cumulative,
  distribution_vis = "points",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  apply_text_formatting +
  guides(color = guide_legend(ncol = 2)) +
  labs(y = "Combined frequency\nof top 50 diagnoses", x = NULL)

# Shannon diversity per criteria for all six models
plt_shannon <- multi_shannon_plot(
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd,
  distribution_vis = "points",
  wrap_width = 45,
  n_diag = 25
) +
  apply_text_formatting +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(ncol = 2))

# Within-criteria precision (mean Bray-Curtis similarity) with BH-adjusted
# pairwise Wilcoxon tests; Gemini 1.5 Flash rows excluded
plt_precision <- here("data/diversity_analysis/compiled_icd_diagnosis_precision.csv") %>%
  read_csv() %>%
  precision_dist_to_sim() %>%
  format_criteria() %>%
  format_models() %>%
  filter(model != "Gemini 1.5 Flash") %>%
  ggplot(aes(x = criteria, y = mean)) +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  labs(x = "", y = "Mean Bray-Curtis Similarity") +
  ggpubr::geom_pwc(
    aes(group = criteria),
    method = "wilcox.test",
    p.adjust.method = "BH",
    hide.ns = TRUE,
    label = "p.adj.signif",
    bracket.nudge.y = 0.3,
    vjust = 0.6,
    step.increase = 0.14,
    tip.length = 0.02
  ) +
  labs(color = NULL) +
  scale_color_brewer(palette = "Dark2") +
  plot_selector("points") +
  apply_text_formatting +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(ncol = 2))

# Combined Bray-Curtis similarity heatmap (no dendrogram, horizontal legend)
plt_similarity <- multi_diagnosis_similarity_heatmap(
  method = "bray",
  show_dend = FALSE,
  legend_label = "Bray-Curtis similarity",
  legend_direction = "horizontal",
  label_size = 6,
  title_size = 9,
  df_gpt3.5_icd, df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd, df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd, df_gemini1.5_pro_t1.0_icd
)

# Figure layout: panel A on top, spacer row, then the similarity heatmap (B),
# rank plot (C) and the Shannon/precision pair (D, E) with a shared legend.
full_plt <- plot_grid(
  ###
  plt_diags,
  ###
  NULL,
  plot_grid(
    # ComplexHeatmap objects are not ggplots; capture the drawn heatmap as a
    # grob so plot_grid() can place it.
    grid::grid.grabExpr(ComplexHeatmap::draw(plt_similarity, heatmap_legend_side = 'bottom')),
    plt_rank,
    # NULL,
    plot_grid(
      plot_grid(
        plt_shannon + theme(legend.position = "none"),
        plt_precision + theme(legend.position = "none"),
        nrow = 1,
        axis = 'tb',
        align = 'h',
        labels = c(LETTERS[4:5])
      ),
      # Shared legend laid out in a single row; guide_legend() takes `nrow`
      # (it has no `row` argument).
      get_legend(plt_shannon + guides(color = guide_legend(nrow = 1))),
      ncol = 1,
      rel_heights = c(1, 0.1)
    ),
    nrow = 1,
    # rel_widths = c(1, 0.01, 0.8, 0.9),
    rel_widths = c(0.8, 1, 0.9),
    labels = c(LETTERS[2:3]),
    vjust = 0.2
  ),
  ncol = 1,
  rel_heights = c(1.2, 0.05, 0.65),
  labels = c("A", "", "")
)

full_plt
```


Things to fix
- Legend position for C-E
- Legend width for B
- Move legend for A to the left of "Frequency?"
- Rank plot line weight

```{r}
# Write the assembled figure to a 7.5 x 7.5 PDF
ggsave(
  filename = here("figures/3_diagnosis_diversity.pdf"),
  plot = full_plt,
  width = 7.5,
  height = 7.5
)
```

`set_table_properties(opts_pdf = list(tabcolsep = 0))`
```{r}
set_flextable_defaults(fonts_ignore = TRUE)

# Rank table for ICD codes T78.2, D47.02, D89.41, D89.49 and D89.4 across all
# six models (the same model set as rank_table earlier in this notebook),
# rendered as a PDF preview and then returned.
multi_diagnosis_rank_table(
  search_pattern = "T78\\.2 |D47\\.02 |D89\\.41 |D89\\.49 |D89\\.4 ",
  df_gpt3.5_icd,
  df_gpt4.0_icd,
  df_claude3_haiku_t1.0_icd,
  df_claude3_opus_t1.0_icd,
  df_gemini1.0_pro_t1.0_icd,
  df_gemini1.5_pro_t1.0_icd
) %>%
  flextable() %>%
  width(width = 2) %>%
  fontsize(size = 9) %>%
  fontsize(size = 10, part = "header") %>%
  padding(padding = 0) %>%
  align(j = 2:3, align = "center", part = "all") %>%
  set_table_properties(opts_pdf = list(arraystretch = 1.25)) %>%
  {print(., preview = "pdf");.}
```


